{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b35d2e8f",
   "metadata": {},
   "source": [
    "# Comp1-传统组学\n",
    "\n",
    "主要适配于传统组学的建模和刻画。典型的应用场景探究rad_score最最终临床诊断的作用。\n",
    "\n",
    "数据的一般形式为(具体文件,文件夹名可以不同)：\n",
    "1. `images`文件夹，存放研究对象所有的CT、MRI等数据。\n",
    "2. `masks`文件夹, 存放手工（Manuelly）勾画的ROI区域。与images文件夹的文件意义对应。\n",
    "3. `label.txt`文件，每个患者对应的标签，例如肿瘤的良恶性、5年存活状态等。\n",
    "\n",
    "## Onekey步骤\n",
    "\n",
    "1. 数据校验，检查数据格式是否正确。\n",
    "2. 组学特征提取，如果第一步检查数据通过，则提取对应数据的特征。\n",
    "3. 读取标注数据信息。\n",
    "4. 特征与标注数据拼接。形成数据集。\n",
    "5. 查看一些统计信息，检查数据时候存在异常点。\n",
    "6. 正则化，将数据变化到服从 N~(0, 1)。\n",
    "7. 通过相关系数，例如spearman、person等筛选出特征。\n",
    "8. 构建训练集和测试集，这里使用的是随机划分，正常多中心验证，需要大家根据自己的场景构建两份数据。\n",
    "9. 通过Lasso筛选特征，选取其中的非0项作为后续模型的特征。\n",
    "10. 使用机器学习算法，例如LR、SVM、RF等进行任务学习。\n",
    "11. 模型结果可视化，例如AUC、ROC曲线，混淆矩阵等。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8747c587",
   "metadata": {},
   "outputs": [],
   "source": [
    "## 获得视频教程\n",
    "from onekey_algo.custom.Manager import onekey_show\n",
    "onekey_show('传统组学任务')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bb054795",
   "metadata": {},
   "source": [
    "## 一、数据校验\n",
    "首先需要检查诊断数据，如果显示`检查通过！`择可以正常运行之后的，否则请根据提示调整数据。\n",
    "\n",
    "**注意**：这里要求images和masks文件夹中的文件名必须一一对应。e.g. `1.nii.gz`为images中的一个文件，在masks文件夹必须也存在一个`1.nii.gz`文件。\n",
    "\n",
    "当然也可以使用自定义的函数，获取解析数据。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6e7bdf17",
   "metadata": {},
   "source": [
    "### 指定数据\n",
    "\n",
    "此模块有3个需要自己定义的参数\n",
    "\n",
    "1. `mydir`: 数据存放的路径。\n",
    "2. `labelf`: 每个样本的标注信息文件。\n",
    "3. `labels`: 要让AI系统学习的目标，例如肿瘤的良恶性、T-stage等。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3895d5fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 数据检验视频\n",
    "onekey_show('传统组学任务|数据检验')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3ddc852",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "plt.rcParams['figure.dpi'] = 300\n",
    "from onekey_algo import OnekeyDS as okds\n",
    "# 设置数据目录\n",
    "mydir = okds.ct\n",
    "# 对应的标签文件\n",
    "# labelf = r'你自己标注数据的文件地址'\n",
    "labelf = os.path.join(mydir, 'label_icc.csv')\n",
    "# 读取标签数据列名\n",
    "labels = ['label']\n",
    "\n",
    "os.makedirs('img', exist_ok=True)\n",
    "os.makedirs('results', exist_ok=True)\n",
    "os.makedirs('features', exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1309e89",
   "metadata": {},
   "source": [
    "### images和masks匹配\n",
    "\n",
    "这里要求images和masks文件夹中的文件名必须一一对应。e.g. `1.nii.gz`为images中的一个文件，在masks文件夹必须也存在一个`1.nii.gz`文件。\n",
    "\n",
    "当然也可以使用自定义的函数，获取解析数据。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a4cbe42",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "from onekey_algo.custom.components.Radiology import diagnose_3d_image_mask_settings, get_image_mask_from_dir\n",
    "import re\n",
    "\n",
    "# 生成images和masks对，一对一的关系。也可以自定义替换。\n",
    "images, masks = get_image_mask_from_dir(mydir, images='images', masks='masks')\n",
    "\n",
    "# 自定义获取images和masks数据的方法，下面的例子为，每个样本一个文件夹，图像是以im.nii结尾，mask是以seg.nii结尾。\n",
    "# def get_images_mask(mydir):\n",
    "#     images = []\n",
    "#     masks = []\n",
    "#     for root, dirs, files in os.walk(mydir):\n",
    "#         for f in files:\n",
    "#             if f.endswith('_mask.nii'):\n",
    "#                 masks.append(os.path.join(root, f))\n",
    "#             else:\n",
    "#                 images.append(os.path.join(root, f))\n",
    "#     images = sorted(images, key=lambda x: int(re.split('[.|_]', os.path.basename(x))[0]))\n",
    "#     masks = sorted(masks, key=lambda x: int(re.split('[.|_]', os.path.basename(x))[0]))\n",
    "#     return images, masks\n",
    "# images, masks = get_images_mask(os.path.join(mydir, 'images'))\n",
    "\n",
    "# diagnose_3d_image_mask_settings(images, masks, verbose=True)\n",
    "print(f'获取到{len(images)}个样本。')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f4a8ee8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 特征提取视频\n",
    "onekey_show('传统组学任务|特征提取')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3bb32d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "from onekey_algo.custom.components.Radiology import diagnose_3d_image_mask_settings, get_image_mask_from_dir\n",
    "\n",
    "# 生成images和masks对，一对一的关系。也可以自定义替换。\n",
    "images_icc1, masks_icc1 = get_image_mask_from_dir(os.path.join(mydir, 'ICC', 'ICC1'))\n",
    "images_icc2, masks_icc2 = get_image_mask_from_dir(os.path.join(mydir, 'ICC', 'ICC2'))\n",
    "\n",
    "diagnose_3d_image_mask_settings(images_icc1, masks_icc1)\n",
    "print(f'ICC标注 1，获取到{len(images_icc1)}个样本。')\n",
    "\n",
    "diagnose_3d_image_mask_settings(images_icc2, masks_icc2)\n",
    "print(f'ICC标注 2，获取到{len(images_icc2)}个样本。')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e929522",
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "import pandas as pd\n",
    " \n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "from onekey_algo.custom.components.Radiology import ConventionalRadiomics\n",
    "\n",
    "icc = []\n",
    "for idx, (images_i, masks_i) in enumerate(zip([images_icc1, images_icc2], [masks_icc1, masks_icc2])):\n",
    "    if os.path.exists(f'features/icc_{idx + 1}.csv'):\n",
    "        rad_data = pd.read_csv(f'features/icc_{idx + 1}.csv', header=0)\n",
    "    else:\n",
    "        # 如果要自定义一些特征提取方式，可以使用param_file。\n",
    "#         param_file = r'custom_settings/exampleCT.yaml'\n",
    "        param_file = None\n",
    "        radiomics = ConventionalRadiomics(param_file, correctMask= True)\n",
    "        radiomics.extract(images_i, masks_i)\n",
    "        rad_data = radiomics.get_label_data_frame(label=1)\n",
    "        rad_data.columns = [c.replace('-', '_') for c in rad_data.columns]\n",
    "        rad_data.to_csv(f'features/icc_{idx + 1}.csv', header=True, index=False)\n",
    "    icc.append(rad_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88fecc4e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from onekey_algo.custom.components.comp1 import icc_filter_features\n",
    "icc_sel_features = icc_filter_features(icc, columns_start=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79b88fd9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "import pandas as pd\n",
    " \n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "from onekey_algo.custom.components.Radiology import ConventionalRadiomics\n",
    "\n",
    "if os.path.exists('features/rad_features.csv'):\n",
    "    rad_data = pd.read_csv('features/rad_features.csv', header=0)\n",
    "else:\n",
    "    # 如果要自定义一些特征提取方式，可以使用param_file。\n",
    "#     param_file = r'custom_settings/exampleCT.yaml'\n",
    "    param_file = None\n",
    "    radiomics = ConventionalRadiomics(param_file, correctMask= True)\n",
    "    radiomics.extract(images, masks)\n",
    "    rad_data = radiomics.get_label_data_frame(label=1)\n",
    "#     rad_data.columns = ['ID'] + [f\"CT_{c}\" for c in rad_data.columns[1:]]\n",
    "    rad_data.to_csv('features/rad_features.csv', header=True, index=False)\n",
    "rad_data = rad_data[['ID'] + icc_sel_features]\n",
    "rad_data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "59e2b71c",
   "metadata": {},
   "source": [
    "## 标注数据\n",
    "\n",
    "数据以csv格式进行存储，这里如果是其他格式，可以使用自定义函数读取出每个样本的结果。\n",
    "\n",
    "要求label_data为一个`DataFrame`格式，包括ID列以及后续的labels列，可以是多列，支持Multi-Task。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6503b076",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "label_data = pd.read_csv(labelf)\n",
    "label_data = label_data[['ID', 'group'] + labels]\n",
    "label_data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fe37e701",
   "metadata": {},
   "source": [
    "## 特征拼接 \n",
    "\n",
    "将标注数据`label_data`与`rad_data`进行合并，得到训练数据。\n",
    "\n",
    "**注意：需要删掉ID这一列**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1de1e4e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from onekey_algo.custom.utils import print_join_info\n",
    "\n",
    "# 删掉ID这一列。\n",
    "combined_data = pd.merge(rad_data, label_data, on=['ID'], how='inner')\n",
    "ids = combined_data['ID']\n",
    "combined_data = combined_data.drop(['ID'], axis=1)\n",
    "print_join_info(rad_data, label_data)\n",
    "print(combined_data[labels].value_counts())\n",
    "combined_data.columns"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b5df8daf",
   "metadata": {},
   "source": [
    "### 样本可视化\n",
    "\n",
    "根据特征和label信息，将rad features降维到2维，看不同的label样本在二维空间的分布。\n",
    "\n",
    "**注意**：由于特征空间维度极高，降维难免会有损失，所以二维的可视化仅供参考。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab59cb54",
   "metadata": {},
   "outputs": [],
   "source": [
    "from onekey_algo.custom.components.comp1 import analysis_features\n",
    "analysis_features(rad_data, combined_data[labels[0]])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2f8b1f6e",
   "metadata": {},
   "source": [
    "## 获取到数据的统计信息\n",
    "\n",
    "1. count，统计样本个数。\n",
    "2. mean、std, 对应特征的均值、方差\n",
    "3. min, 25%, 50%, 75%, max，对应特征的最小值，25,50,75分位数，最大值。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa5c306d",
   "metadata": {},
   "outputs": [],
   "source": [
    "combined_data.describe()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f0a13a21",
   "metadata": {},
   "source": [
    "## Z-score\n",
    "\n",
    "`normalize_df` 为onekey中正则化的API，将数据变化到0均值1方差。正则化的方法为\n",
    "\n",
    "$column = \\frac{column - mean}{std}$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33fe3d76",
   "metadata": {},
   "outputs": [],
   "source": [
    "from onekey_algo.custom.components.comp1 import normalize_df\n",
    "data = normalize_df(combined_data, not_norm=labels, group='group')\n",
    "data = data.dropna(axis=1)\n",
    "data.describe()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cd361f98",
   "metadata": {},
   "source": [
    "### 相关系数\n",
    "\n",
    "计算相关系数的方法有3种可供选择\n",
    "1. pearson （皮尔逊相关系数）: standard correlation coefficient\n",
    "\n",
    "2. kendall (肯德尔相关性系数) : Kendall Tau correlation coefficient\n",
    "\n",
    "3. spearman (斯皮尔曼相关性系数): Spearman rank correlation\n",
    "\n",
    "三种相关系数参考：https://blog.csdn.net/zmqsdu9001/article/details/82840332"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9598174b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 相关系数视频\n",
    "onekey_show('传统组学任务|相关系数')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "979a9aeb",
   "metadata": {},
   "outputs": [],
   "source": [
    "pearson_corr = data[data['group'] == 'train'][[c for c in data.columns if c not in labels]].corr('pearson')\n",
    "# kendall_corr = data[[c for c in data.columns if c not in labels]].corr('kendall')\n",
    "# spearman_corr = data[[c for c in data.columns if c not in labels]].corr('spearman')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e5f095c3",
   "metadata": {},
   "source": [
    "### 相关系数可视化\n",
    "\n",
    "通过修改变量名，可以可视化不同相关系数下的相关矩阵。\n",
    "\n",
    "**注意**：当特征特别多的时候（大于100），尽量不要可视化，否则运行时间会特别长。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01558bec",
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "from onekey_algo.custom.components.comp1 import draw_matrix\n",
    "plt.figure(figsize=(100.0, 80.0))\n",
    "\n",
    "# 选择可视化的相关系数\n",
    "draw_matrix(pearson_corr, annot=True, cmap='YlGnBu', cbar=False)\n",
    "plt.savefig(f'img/feature_corr.svg', bbox_inches = 'tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c8190131",
   "metadata": {},
   "source": [
    "### 聚类分析\n",
    "\n",
    "通过修改变量名，可以可视化不同相关系数下的相聚类分析矩阵。\n",
    "\n",
    "注意：当特征特别多的时候（大于100），尽量不要可视化，否则运行时间会特别长。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dcd6297e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "pp = sns.clustermap(pearson_corr, linewidths=.5, figsize=(50.0, 40.0), cmap='YlGnBu')\n",
    "plt.setp(pp.ax_heatmap.get_yticklabels(), rotation=0)\n",
    "plt.savefig(f'img/feature_cluster.svg', bbox_inches = 'tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3fcb5c23",
   "metadata": {},
   "source": [
    "### 特征筛选 -- 相关系数\n",
    "\n",
    "根据相关系数，对于相关性比较高的特征（一般文献取corr>0.9），两者保留其一。\n",
    "\n",
    "```python\n",
    "def select_feature(corr, threshold: float = 0.9, keep: int = 1, topn=10, verbose=False):\n",
    "    \"\"\"\n",
    "    * corr, 相关系数矩阵。\n",
    "    * threshold，筛选的相关系数的阈值，大于阈值的两者保留其一（可以根据keep修改，可以是其二...）。默认阈值为0.9\n",
    "    * keep，可以选择大于相关系数，保留几个，默认只保留一个。\n",
    "    * topn, 每次去掉多少重复特征。\n",
    "    * verbose，是否打印日志\n",
    "    \"\"\"\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43b9d6d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 特征筛选视频\n",
    "onekey_show('传统组学任务|特征筛选')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4fe3df7d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from onekey_algo.custom.components.comp1 import select_feature\n",
    "sel_feature = select_feature(pearson_corr, threshold=0.9, topn=32, verbose=False)\n",
    "\n",
    "sel_feature = sel_feature + labels + ['group']\n",
    "sel_feature"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e88423c1",
   "metadata": {},
   "source": [
    "### 过滤特征\n",
    "\n",
    "通过`sel_feature`过滤出筛选出来的特征。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86e09b3f",
   "metadata": {},
   "outputs": [],
   "source": [
    "sel_data = data[sel_feature]\n",
    "sel_data.describe()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e8e31ba9",
   "metadata": {},
   "source": [
    "## 构建数据\n",
    "\n",
    "将样本的训练数据X与监督信息y分离出来，并且对训练数据进行划分，一般的划分原则为80%-20%"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a923e73",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import onekey_algo.custom.components as okcomp\n",
    "\n",
    "n_classes = 2\n",
    "train_data = sel_data[(sel_data['group'] == 'train')]\n",
    "train_ids = ids[train_data.index]\n",
    "train_data = train_data.reset_index()\n",
    "train_data = train_data.drop('index', axis=1)\n",
    "y_data = train_data[labels]\n",
    "X_data = train_data.drop(labels + ['group'], axis=1)\n",
    "\n",
    "test_data = sel_data[sel_data['group'] != 'train']\n",
    "test_ids = ids[test_data.index]\n",
    "test_data = test_data.reset_index()\n",
    "test_data = test_data.drop('index', axis=1)\n",
    "y_test_data = test_data[labels]\n",
    "X_test_data = test_data.drop(labels + ['group'], axis=1)\n",
    "\n",
    "y_all_data = sel_data[labels]\n",
    "X_all_data = sel_data.drop(labels + ['group'], axis=1)\n",
    "\n",
    "column_names = X_data.columns\n",
    "print(f\"训练集样本数：{X_data.shape}, 验证集样本数：{X_test_data.shape}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6548eaee",
   "metadata": {},
   "source": [
    "### Lasso\n",
    "\n",
    "初始化Lasso模型，alpha为惩罚系数。具体的参数文档可以参考：[文档](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.Lasso.html?highlight=lasso#sklearn.linear_model.Lasso)\n",
    "\n",
    "### 交叉验证\n",
    "\n",
    "不同Lambda下的，特征的的权重大小。\n",
    "```python\n",
    "def lasso_cv_coefs(X_data, y_data, points=50, column_names: List[str] = None, **kwargs):\n",
    "    \"\"\"\n",
    "\n",
    "    Args:\n",
    "        X_data: 训练数据\n",
    "        y_data: 监督数据\n",
    "        points: 打印多少个点。默认50\n",
    "        column_names: 列名，默认为None，当选择的数据很多的时候，建议不要添加此参数\n",
    "        **kwargs: 其他用于打印控制的参数。\n",
    "\n",
    "    \"\"\"\n",
    " ```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28efd09a",
   "metadata": {},
   "outputs": [],
   "source": [
    "alpha = okcomp.comp1.lasso_cv_coefs(X_data, y_data, column_names=None, alpha_logmin=-3)\n",
    "plt.savefig(f'img/feature_lasso.svg', bbox_inches = 'tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "72d1f11f",
   "metadata": {},
   "source": [
    "### 模型效能\n",
    "\n",
    "```python\n",
    "def lasso_cv_efficiency(X_data, y_data, points=50, **kwargs):\n",
    "    \"\"\"\n",
    "\n",
    "    Args:\n",
    "        Xdata: 训练数据\n",
    "        ydata: 测试数据\n",
    "        points: 打印的数据密度\n",
    "        **kwargs: 其他的图像样式\n",
    "            # 数据点标记, fmt=\"o\"\n",
    "            # 数据点大小, ms=3\n",
    "            # 数据点颜色, mfc=\"r\"\n",
    "            # 数据点边缘颜色, mec=\"r\"\n",
    "            # 误差棒颜色, ecolor=\"b\"\n",
    "            # 误差棒线宽, elinewidth=2\n",
    "            # 误差棒边界线长度, capsize=2\n",
    "            # 误差棒边界厚度, capthick=1\n",
    "    Returns:\n",
    "    \"\"\"\n",
    " ```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca62636c",
   "metadata": {},
   "outputs": [],
   "source": [
    "okcomp.comp1.lasso_cv_efficiency(X_data, y_data, points=50, alpha_logmin=-3)\n",
    "plt.savefig(f'img/feature_mse_label.svg', bbox_inches = 'tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4c02f30b",
   "metadata": {},
   "source": [
    "### 惩罚系数\n",
    "\n",
    "使用交叉验证的惩罚系数作为模型训练的基础。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f29840a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn import linear_model\n",
    "\n",
    "models = []\n",
    "for label in labels:\n",
    "    clf = linear_model.Lasso(alpha=alpha)\n",
    "    clf.fit(X_data, y_data[label])\n",
    "    models.append(clf)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d6bda574",
   "metadata": {},
   "source": [
    "### 特征筛选\n",
    "\n",
    "筛选出其中coef > 0的特征。并且打印出相应的公式。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e787ee17",
   "metadata": {},
   "outputs": [],
   "source": [
    "COEF_THRESHOLD = 1e-8 # 筛选的特征阈值\n",
    "scores = []\n",
    "selected_features = []\n",
    "for label, model in zip(labels, models):\n",
    "    feat_coef = [(feat_name, coef) for feat_name, coef in zip(column_names, model.coef_) \n",
    "                 if COEF_THRESHOLD is None or abs(coef) > COEF_THRESHOLD]\n",
    "    selected_features.append([feat for feat, _ in feat_coef])\n",
    "    formula = ' '.join([f\"{coef:+.6f} * {feat_name}\" for feat_name, coef in feat_coef])\n",
    "    score = f\"{label} = {model.intercept_} {'+' if formula[0] != '-' else ''} {formula}\"\n",
    "    scores.append(score)\n",
    "    \n",
    "print(scores[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eb884b8e",
   "metadata": {},
   "source": [
    "### 特征权重"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26a22e86",
   "metadata": {},
   "outputs": [],
   "source": [
    "feat_coef = sorted(feat_coef, key=lambda x: x[1])\n",
    "feat_coef_df = pd.DataFrame(feat_coef, columns=['feature_name', 'Coefficients'])\n",
    "feat_coef_df.plot(x='feature_name', y='Coefficients', kind='barh')\n",
    "\n",
    "plt.savefig(f'img/feature_weights.svg', bbox_inches = 'tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0184fdad",
   "metadata": {},
   "source": [
    "### 进一步筛选特征\n",
    "\n",
    "使用Lasso筛选出来的Coefficients比较高的特征作为训练数据。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10064550",
   "metadata": {},
   "outputs": [],
   "source": [
    "X_data = X_data[selected_features[0]]\n",
    "X_test_data = X_test_data[selected_features[0]]\n",
    "X_data.columns"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "86b7e7da",
   "metadata": {},
   "source": [
    "## 模型筛选\n",
    "\n",
    "根据筛选出来的数据，做模型的初步选择。当前主要使用到的是Onekey中的\n",
    "\n",
    "1. SVM，支持向量机，引用参考。\n",
    "2. KNN，K紧邻，引用参考。\n",
    "3. Decision Tree，决策树，引用参考。\n",
    "4. Random Forests, 随机森林，引用参考。\n",
    "5. XGBoost, bosting方法。引用参考。\n",
    "6. LightGBM, bosting方法，引用参考。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fbab32fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_names = ['SVM', 'KNN', 'RandomForest', 'ExtraTrees', 'XGBoost', 'LightGBM', 'NaiveBayes', 'AdaBoost', 'GradientBoosting', 'LR', 'MLP']\n",
    "models = okcomp.comp1.create_clf_model(model_names)\n",
    "model_names = list(models.keys())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c369ea2d",
   "metadata": {},
   "source": [
    "### 交叉验证\n",
    "\n",
    "`n_trails`指定随机次数，每次采用的是80%训练，随机20%进行测试，找到最好的模型，以及对应的最好的数据划分。\n",
    "\n",
    "```python\n",
    "def get_bst_split(X_data: pd.DataFrame, y_data: pd.DataFrame,\n",
    "            models: dict, test_size=0.2, metric_fn=accuracy_score, n_trails=10,\n",
    "            cv: bool = False, shuffle: bool = False, metric_cut_off: float = None, random_state=None):\n",
    "    \"\"\"\n",
    "    寻找数据集中最好的数据划分。\n",
    "    Args:\n",
    "        X_data: 训练数据\n",
    "        y_data: 监督数据\n",
    "        models: 模型名称，Dict类型、\n",
    "        test_size: 测试集比例\n",
    "        metric_fn: 评价模型好坏的函数，默认准确率，可选roc_auc_score。\n",
    "        n_trails: 尝试多少次寻找最佳数据集划分。\n",
    "        cv: 是否是交叉验证，默认是False，当为True时，n_trails为交叉验证的n_fold\n",
    "        shuffle: 是否进行随机打乱\n",
    "        metric_cut_off: 当metric_fn的值达到多少时进行截断。\n",
    "        random_state: 随机种子\n",
    "\n",
    "    Returns: {'max_idx': max_idx, \"max_model\": max_model, \"max_metric\": max_metric, \"results\": results}\n",
    "\n",
    "    \"\"\"\n",
    "```\n",
    "\n",
    "**注意：这里采用了【挑数据】，如果想要严谨，请修改`n_trails=1`。**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7bedcf86",
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "from sklearn.metrics import accuracy_score, roc_auc_score\n",
    "\n",
    "# 随机使用n_trails次数据划分，找到最好的一次划分方法，并且保存在results中。\n",
    "results = okcomp.comp1.get_bst_split(X_data, y_data, models, test_size=0.5, metric_fn=roc_auc_score, n_trails=2, cv=True, random_state=0)\n",
    "_, (X_train_sel, X_test_sel, y_train_sel, y_test_sel) = results['results'][results['max_idx']]\n",
    "X_train_sel, X_test_sel, y_train_sel, y_test_sel = X_data, X_test_data, y_data, y_test_data\n",
    "trails, _ = zip(*results['results'])\n",
    "cv_results = pd.DataFrame(trails, columns=model_names)\n",
    "# 可视化每个模型在不同的数据划分中的效果。\n",
    "sns.boxplot(data=cv_results)\n",
    "plt.ylabel('AUC %')\n",
    "plt.xlabel('Model Nmae')\n",
    "plt.savefig(f'img/model_cv.svg', bbox_inches = 'tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2ac7440e",
   "metadata": {},
   "source": [
    "## 模型筛选\n",
    "\n",
    "使用最好的数据划分，进行后续的模型研究。\n",
    "\n",
    "**注意**: 一般情况下论文使用的是随机划分的数据，但也有些论文使用【刻意】筛选的数据划分。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7158f70a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import joblib\n",
    "from onekey_algo.custom.components.comp1 import plot_feature_importance, plot_learning_curve, smote_resample\n",
    "targets = []\n",
    "os.makedirs('models', exist_ok=True)\n",
    "for l in labels:\n",
    "    new_models = list(okcomp.comp1.create_clf_model(model_names).values())\n",
    "    for mn, m in zip(model_names, new_models):\n",
    "        X_train_smote, y_train_smote = X_train_sel, y_train_sel\n",
    "        # 取消下一行的注释可以使用Smote进行采样，解决样本不均衡的问题。\n",
    "#         X_train_smote, y_train_smote = smote_resample(X_train_sel, y_train_sel)\n",
    "        m.fit(X_train_smote, y_train_smote[l])\n",
    "        # 保存训练的模型\n",
    "        joblib.dump(m, f'models/{mn}_{l}.pkl') \n",
    "        # 输出模型特征重要性，只针对高级树模型有用\n",
    "        plot_feature_importance(m, selected_features[0], save_dir='img')\n",
    "        \n",
    "#         plot_learning_curve(m, X_train_sel, y_train_sel, title=f'Learning Curve {mn}')\n",
    "#         plt.savefig(f\"img/Rad_{mn}_learning_curve.svg\", bbox_inches='tight')\n",
    "        plt.show()\n",
    "    targets.append(new_models)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "88194089",
   "metadata": {},
   "source": [
    "## 预测结果\n",
    "\n",
    "* predictions，二维数据，每个label对应的每个模型的预测结果。\n",
    "* pred_scores，二维数据，每个label对应的每个模型的预测概率值。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba87adcf",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import accuracy_score\n",
    "from sklearn.preprocessing import OneHotEncoder\n",
    "from onekey_algo.custom.components.delong import calc_95_CI\n",
    "from onekey_algo.custom.components.metrics import analysis_pred_binary\n",
    "\n",
    "predictions = [[(model.predict(X_train_sel), model.predict(X_test_sel)) \n",
    "                for model in target] for label, target in zip(labels, targets)]\n",
    "pred_scores = [[(model.predict_proba(X_train_sel), model.predict_proba(X_test_sel)) \n",
    "                for model in target] for label, target in zip(labels, targets)]\n",
    "\n",
    "metric = []\n",
    "pred_sel_idx = []\n",
    "for label, prediction, scores in zip(labels, predictions, pred_scores):\n",
    "    pred_sel_idx_label = []\n",
    "    for mname, (train_pred, test_pred), (train_score, test_score) in zip(model_names, prediction, scores):\n",
    "        # 计算训练集指数\n",
    "        acc, auc, ci, tpr, tnr, ppv, npv, precision, recall, f1, thres = analysis_pred_binary(y_train_sel[label], \n",
    "                                                                                              train_score[:, 1])\n",
    "        ci = f\"{ci[0]:.4f} - {ci[1]:.4f}\"\n",
    "        metric.append((mname, acc, auc, ci, tpr, tnr, ppv, npv, precision, recall, f1, thres, f\"{label}-train\"))\n",
    "                 \n",
    "        # 计算验证集指标\n",
    "        acc, auc, ci, tpr, tnr, ppv, npv, precision, recall, f1, thres = analysis_pred_binary(y_test_sel[label], \n",
    "                                                                                              test_score[:, 1])\n",
    "        ci = f\"{ci[0]:.4f} - {ci[1]:.4f}\"\n",
    "        metric.append((mname, acc, auc, ci, tpr, tnr, ppv, npv, precision, recall, f1, thres, f\"{label}-test\"))\n",
    "        # 计算thres对应的sel idx\n",
    "        pred_sel_idx_label.append(np.logical_or(test_score[:, 0] >= thres, test_score[:, 1] >= thres))\n",
    "    \n",
    "    pred_sel_idx.append(pred_sel_idx_label)\n",
    "metric = pd.DataFrame(metric, index=None, columns=['model_name', 'Accuracy', 'AUC', '95% CI',\n",
    "                                                   'Sensitivity', 'Specificity', \n",
    "                                                   'PPV', 'NPV', 'Precision', 'Recall', 'F1',\n",
    "                                                   'Threshold', 'Task'])\n",
    "metric"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d50bc4d4",
   "metadata": {},
   "source": [
    "### 绘制曲线\n",
    "\n",
    "绘制的不同模型的准确率柱状图和折线图曲线。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71f86dcd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "\n",
    "plt.figure(figsize=(10, 10))\n",
    "plt.subplot(211)\n",
    "sns.barplot(x='model_name', y='Accuracy', data=metric, hue='Task')\n",
    "plt.subplot(212)\n",
    "sns.lineplot(x='model_name', y='Accuracy', data=metric, hue='Task')\n",
    "plt.savefig(f'img/model_acc.svg', bbox_inches = 'tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3f5585c2",
   "metadata": {},
   "source": [
    "### 绘制ROC曲线\n",
    "确定最好的模型，并且绘制曲线。\n",
    "\n",
    "```python\n",
    "def draw_roc(y_test, y_score, title='ROC', labels=None):\n",
    "```\n",
    "\n",
    "`sel_model = ['SVM', 'KNN']`参数为想要绘制的模型对应的参数。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b610f5d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "sel_model = model_names\n",
    "\n",
    "for sm in sel_model:\n",
    "    if sm in model_names:\n",
    "        sel_model_idx = model_names.index(sm)\n",
    "    \n",
    "        # Plot all ROC curves\n",
    "        plt.figure(figsize=(8, 8))\n",
    "        for pred_score, label in zip(pred_scores, labels):\n",
    "            okcomp.comp1.draw_roc([np.array(y_train_sel[label]), np.array(y_test_sel[label])], \n",
    "                                  pred_score[sel_model_idx], \n",
    "                                  labels=['Train', 'Test'], title=f\"Model: {sm}\")\n",
    "            plt.savefig(f'img/model_{sm}_roc.svg', bbox_inches = 'tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "10899ae5",
   "metadata": {},
   "source": [
    "#### 汇总所有模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20e770a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "sel_model = model_names\n",
    "\n",
    "for pred_score, label in zip(pred_scores, labels):\n",
    "    pred_test_scores = []\n",
    "    for sm in sel_model:\n",
    "        if sm in model_names:\n",
    "            sel_model_idx = model_names.index(sm)\n",
    "            pred_test_scores.append(pred_score[sel_model_idx][1])\n",
    "    okcomp.comp1.draw_roc([np.array(y_test_sel[label])] * len(pred_test_scores), \n",
    "                          pred_test_scores, \n",
    "                          labels=sel_model, title=f\"Model AUC\")\n",
    "    plt.savefig(f'img/model_roc.svg', bbox_inches = 'tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "24e955e3",
   "metadata": {},
   "source": [
    "### DCA 决策曲线"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ab01b0b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from onekey_algo.custom.components.comp1 import plot_DCA\n",
    "\n",
    "for pred_score, label in zip(pred_scores, labels):\n",
    "    pred_test_scores = []\n",
    "    for sm in sel_model:\n",
    "        if sm in model_names:\n",
    "            sel_model_idx = model_names.index(sm)\n",
    "            okcomp.comp1.plot_DCA(pred_score[sel_model_idx][1][:,1], np.array(y_test_sel[label]),\n",
    "                                  title=f'Rad Model {sm} DCA')\n",
    "            plt.savefig(f'img/model_{sm}_dca.svg', bbox_inches = 'tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "96ae6a46",
   "metadata": {},
   "source": [
    "### 绘制混淆矩阵\n",
    "\n",
    "绘制混淆矩阵，[混淆矩阵解释](https://baike.baidu.com/item/%E6%B7%B7%E6%B7%86%E7%9F%A9%E9%98%B5/10087822?fr=aladdin)\n",
    "`sel_model = ['SVM', 'KNN']`参数为想要绘制的模型对应的参数。\n",
    "\n",
    "如果需要修改标签到名称的映射，修改`class_mapping={1:'1', 0:'0'}`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5012998",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 设置绘制参数\n",
    "sel_model = model_names\n",
    "c_matrix = {}\n",
    "\n",
    "for sm in sel_model:\n",
    "    if sm in model_names:\n",
    "        sel_model_idx = model_names.index(sm)\n",
    "        for idx, label in enumerate(labels):\n",
    "            cm = okcomp.comp1.calc_confusion_matrix(predictions[idx][sel_model_idx][-1], y_test_sel[label],\n",
    "#                                                     sel_idx = pred_sel_idx[idx][sel_model_idx],\n",
    "                                                    class_mapping={1:'1', 0:'0'}, num_classes=2)\n",
    "            c_matrix[label] = cm\n",
    "            plt.figure(figsize=(5, 4))\n",
    "            plt.title(f'Rad Model:{sm}')\n",
    "            okcomp.comp1.draw_matrix(cm, norm=True, annot=True, cmap='Blues')\n",
    "            plt.savefig(f'img/model_{sm}_cm.svg', bbox_inches = 'tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4f856e3f",
   "metadata": {},
   "source": [
    "### 样本预测直方图\n",
    "\n",
    "绘制每个样本的预测结果以及对应的真实结果, 图例中label=xx可以修改成自己类别的真实标签。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36c4be77",
   "metadata": {},
   "outputs": [],
   "source": [
    "sel_model = model_names\n",
    "c_matrix = {}\n",
    "\n",
    "for sm in sel_model:\n",
    "    if sm in model_names:\n",
    "        sel_model_idx = model_names.index(sm)\n",
    "        for idx, label in enumerate(labels):            \n",
    "            okcomp.comp1.draw_predict_score(pred_scores[idx][sel_model_idx][-1], y_test_sel[label])\n",
    "            plt.title(f'{sm} sample predict score')\n",
    "            plt.legend(labels=[\"label=0\",\"label=1\"],loc=\"lower right\") \n",
    "            plt.savefig(f'img/model_{sm}_sample_dis.svg', bbox_inches = 'tight')\n",
    "            plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1da92679",
   "metadata": {},
   "source": [
    "## 保存模型结果\n",
    "\n",
    "可以把模型预测的标签结果以及每个类别的概率都保存下来。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa78e5eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "\n",
    "os.makedirs('results', exist_ok=True)\n",
    "sel_model = sel_model\n",
    "\n",
    "for idx, label in enumerate(labels):\n",
    "    for sm in sel_model:\n",
    "        if sm in model_names:\n",
    "            sel_model_idx = model_names.index(sm)\n",
    "            target = targets[idx][sel_model_idx]\n",
    "            # 预测训练集和测试集数据。\n",
    "            train_indexes = np.reshape(np.array(train_ids), (-1, 1)).astype(str)\n",
    "            test_indexes = np.reshape(np.array(test_ids), (-1, 1)).astype(str)\n",
    "            y_train_pred_scores = target.predict_proba(X_train_sel)\n",
    "            y_test_pred_scores = target.predict_proba(X_test_sel)\n",
    "            columns = ['ID'] + [f\"{label}-{i}\"for i in range(y_test_pred_scores.shape[1])]\n",
    "            # 保存预测的训练集和测试集结果\n",
    "            result_train = pd.DataFrame(np.concatenate([train_indexes, y_train_pred_scores], axis=1), columns=columns)\n",
    "            result_train.to_csv(f'results/{sm}_Rad_train.csv', index=False)\n",
    "            result_test = pd.DataFrame(np.concatenate([test_indexes, y_test_pred_scores], axis=1), columns=columns)\n",
    "            result_test.to_csv(f'results/{sm}_Rad_test.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f964b08",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
